---
title: Using fastai for segmentation
keywords: fastai
sidebar: home_sidebar
nb_path: "nbs/examples.fastai.segmentation.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
 
{% endraw %} {% raw %}
%matplotlib inline

from pathlib import Path
from drone_detector.processing.tiling import *
import os, sys
import geopandas as gpd
from fastai.vision.all import *
from drone_detector.engines.fastai.data import *
{% endraw %} {% raw %}
outpath = Path('../data/historic_map/processed/raster_tiles/')

fnames = [outpath/f for f in os.listdir(outpath)]

dls = SegmentationDataLoaders.from_label_func('../data/historic_map/', bs=16,
                                              codes=['Marshes'],
                                              fnames=fnames,
                                              label_func=partial(label_from_different_folder,
                                                                 original_folder='raster_tiles',
                                                                 new_folder='mask_tiles'),
                                              batch_tfms = [
                                                  *aug_transforms(max_rotate=0., max_warp=0.),
                                                  Normalize.from_stats(*imagenet_stats)
                                              ])
{% endraw %}

`label_from_different_folder` is a helper located in `drone_detector.engines.fastai.data`. That module also contains helpers for other data types, such as multispectral images or image time series.
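As a rough sketch of what such a label function does (the actual implementation lives in `drone_detector.engines.fastai.data`; this is a hypothetical re-implementation, not the library code), it simply swaps the folder name in the image path so that each raster tile points at its mask tile of the same name:

```python
from pathlib import Path

def label_from_different_folder(fn, original_folder, new_folder):
    # Hypothetical sketch: map an image path to the corresponding mask path
    # by replacing the folder name and keeping the filename unchanged.
    return Path(str(fn).replace(original_folder, new_folder))
```

Because the function takes extra arguments, it is wrapped with `partial` above so that `SegmentationDataLoaders` can call it with the filename alone.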

{% raw %}
dls.show_batch(max_n=16)
{% endraw %}

Train a basic U-Net, using a pretrained ResNet50 as the encoder. `to_fp16()` tells the model to use mixed-precision training, which uses less memory. The loss function is `FocalLossFlat`; for segmentation we need to specify `axis=1`. Metrics are `Dice` and `JaccardCoeff`, fairly standard segmentation metrics.
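For intuition about the two metrics: Dice is `2|A∩B| / (|A| + |B|)` and Jaccard (IoU) is `|A∩B| / |A∪B|`, and they relate as `J = D / (2 - D)`. A minimal NumPy sketch of both on binary masks (not the fastai implementation, which flattens predictions along `axis=1`):

```python
import numpy as np

def dice_and_jaccard(pred, targ):
    # Dice = 2|A∩B| / (|A| + |B|); Jaccard (IoU) = |A∩B| / |A∪B|
    inter = np.logical_and(pred, targ).sum()
    dice = 2 * inter / (pred.sum() + targ.sum())
    jaccard = inter / np.logical_or(pred, targ).sum()
    return dice, jaccard

pred = np.array([[1, 1], [0, 0]])
targ = np.array([[1, 0], [1, 0]])
print(dice_and_jaccard(pred, targ))  # Dice = 2*1/4 = 0.5, IoU = 1/3
```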

{% raw %}
learn = unet_learner(dls, resnet50, pretrained=True, n_in=3, n_out=2,
                     metrics=[Dice(), JaccardCoeff()], loss_func=FocalLossFlat(axis=1)
                    ).to_fp16()
{% endraw %}

Search for a suitable learning rate.

{% raw %}
learn.lr_find()
SuggestedLRs(valley=4.365158383734524e-05)
{% endraw %}

Train the model for 2 epochs with the encoder layers frozen, then for 10 epochs with all layers unfrozen.

{% raw %}
from fastai.callback.progress import ShowGraphCallback
learn.fine_tune(10, freeze_epochs=2, base_lr=1e-4, cbs=ShowGraphCallback)
| epoch | train_loss | valid_loss | dice | jaccard_coeff | time |
|-------|-----------|-----------|----------|---------------|-------|
| 0 | 0.224074 | 0.083597 | 0.017181 | 0.008665 | 00:13 |
| 1 | 0.189758 | 0.064810 | 0.598005 | 0.426539 | 00:11 |

| epoch | train_loss | valid_loss | dice | jaccard_coeff | time |
|-------|-----------|-----------|----------|---------------|-------|
| 0 | 0.044298 | 0.026398 | 0.747473 | 0.596772 | 00:11 |
| 1 | 0.037180 | 0.030933 | 0.729766 | 0.574512 | 00:11 |
| 2 | 0.031749 | 0.026626 | 0.774982 | 0.632629 | 00:11 |
| 3 | 0.031116 | 0.014586 | 0.859288 | 0.753291 | 00:11 |
| 4 | 0.027017 | 0.015750 | 0.863710 | 0.760114 | 00:11 |
| 5 | 0.023494 | 0.013853 | 0.857760 | 0.750945 | 00:11 |
| 6 | 0.021318 | 0.017538 | 0.852509 | 0.742932 | 00:11 |
| 7 | 0.019447 | 0.012881 | 0.885675 | 0.794809 | 00:11 |
| 8 | 0.018104 | 0.012061 | 0.889665 | 0.801259 | 00:11 |
| 9 | 0.016782 | 0.013043 | 0.885664 | 0.794792 | 00:11 |
{% endraw %}

Return to full precision.

{% raw %}
learn.to_fp32()
<fastai.learner.Learner at 0x7f16edcbf760>
{% endraw %}

Check results.

{% raw %}
learn.show_results(max_n=8)
{% endraw %}

Export the model for later use.

{% raw %}
learn.path = Path('../data/historic_map/models')
learn.export('resnet50_focalloss_swamps.pkl')
{% endraw %}

Define some helper functions for inference, such as patches for removing all resizing transforms from the dataloaders.

{% raw %}
def label_func(fn):
    # Map an image tile path to the corresponding mask tile path
    return str(fn).replace('raster_tiles', 'mask_tiles')

@patch
def remove(self:Pipeline, t):
    # Remove all transforms of the same class as `t` from the pipeline.
    # Iterate over a copy so removal doesn't skip items mid-iteration.
    for o in list(self.fs):
        if isinstance(o, t.__class__): self.fs.remove(o)

@patch
def set_base_transforms(self:DataLoader):
    # Strip any item- or batch-level transform that has a `size` attribute,
    # i.e. anything that resizes, so inference runs on full-size tiles.
    for attr in ['after_item', 'after_batch']:
        tfms = getattr(self, attr)
        for o in list(tfms):
            if hasattr(o, 'size'):
                tfms.remove(o)
        setattr(self, attr, tfms)
{% endraw %}

Load learners and remove all resizing transforms.

{% raw %}
testlearn = load_learner('../data/historic_map/models/resnet50_focalloss_swamps.pkl', cpu=False)
testlearn.dls.valid.set_base_transforms()
{% endraw %}

The model is tested with 4 map patches of different areas and sizes. Two of the images are from 1965 and two from 1984. Image sizes vary between 600×600 and 1500×1500 pixels.

{% raw %}
import PIL

def unet_predict(fn):
    # Predict a mask for image `fn` and black out everything outside it
    image = np.array(PIL.Image.open(fn))
    mask = testlearn.predict(PILImage.create(image))[0].numpy()
    img = image.copy()
    img[mask == 0] = 0  # zero all channels where no marsh was predicted
    return PIL.Image.fromarray(img.astype(np.uint8))
{% endraw %} {% raw %}
test_images = [f'../data/historic_map/test_patches/{f}' for f in os.listdir('../data/historic_map/test_patches/')]
{% endraw %}

First result.

{% raw %}
patch_pred = unet_predict(test_images[0])

fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
for a in axs:
    a.set_yticks([])
    a.set_xticks([])
axs[0].imshow(PIL.Image.open(test_images[0]))
axs[1].imshow(patch_pred)
axs[0].set_title(test_images[0].split('/')[-1])
axs[1].set_title('Predicted marshes')
plt.show()
{% endraw %}

Second result.

{% raw %}
patch_pred = unet_predict(test_images[1])

fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
for a in axs:
    a.set_yticks([])
    a.set_xticks([])
axs[0].imshow(PIL.Image.open(test_images[1]))
axs[1].imshow(patch_pred)
axs[0].set_title(test_images[1].split('/')[-1])
axs[1].set_title('Predicted marshes')
plt.show()
{% endraw %}

Third result.

{% raw %}
patch_pred = unet_predict(test_images[2])

fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
for a in axs:
    a.set_yticks([])
    a.set_xticks([])
axs[0].imshow(PIL.Image.open(test_images[2]))
axs[1].imshow(patch_pred)
axs[0].set_title(test_images[2].split('/')[-1])
axs[1].set_title('Predicted marshes')
plt.show()
{% endraw %}

Fourth result.

{% raw %}
patch_pred = unet_predict(test_images[3])

fig, axs = plt.subplots(1,2, figsize=(10,5),dpi=300)
for a in axs:
    a.set_yticks([])
    a.set_xticks([])
axs[0].imshow(PIL.Image.open(test_images[3]))
axs[1].imshow(patch_pred)
axs[0].set_title(test_images[3].split('/')[-1])
axs[1].set_title('Predicted marshes')
plt.show()
{% endraw %}